load("//bazel:api.bzl", "mojo_test")

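# Special configuration for the multimem allreduce test: it is only enabled on
# H100 and B200 (SM90+) multi-GPU configurations; all other platforms are
# marked incompatible and the test is skipped.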
# Entry-point file of the multimem allreduce test. It is also used to derive the
# target name, which evaluates to "test_multimem_allreduce.mojo.test".
_MULTIMEM_ALLREDUCE_MAIN = "test_multimem_allreduce.mojo"

mojo_test(
    name = _MULTIMEM_ALLREDUCE_MAIN + ".test",
    size = "large",
    # test_allreduce.mojo is compiled together with the multimem entry point
    # (presumably for shared helpers — confirm); `main` selects which file runs.
    srcs = [
        "test_allreduce.mojo",
        _MULTIMEM_ALLREDUCE_MAIN,
    ],
    # Device-context buffer cache settings, passed to the test as environment
    # variables.
    env = {
        "MODULAR_DEVICE_CONTEXT_BUFFER_CACHE_ONLY": "false",
        "MODULAR_DEVICE_CONTEXT_BUFFER_CACHE_SIZE_PERCENT": "80",
    },
    # Execution property requesting GPU memory for the test (units are not
    # stated here — confirm against the executor configuration).
    exec_properties = {
        "test.resources:gpu-memory": "4",
    },
    main = _MULTIMEM_ALLREDUCE_MAIN,
    tags = ["gpu"],
    # Only the H100 and B200 configurations are eligible, and each additionally
    # requires the //:has_multi_gpu constraint; every other configuration is
    # marked incompatible so the test is skipped rather than failing.
    target_compatible_with = select({
        "//:h100_gpu": ["//:has_multi_gpu"],
        "//:b200_gpu": ["//:has_multi_gpu"],
        "//conditions:default": ["@platforms//:incompatible"],
    }),
    deps = [
        "//max:internal_utils",
        "//max:kv_cache",
        "//max:linalg",
        "//max:nn",
        "//max:quantization",
        "@mojo//:stdlib",
    ],
)

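# Regular tests: one multi-GPU `mojo_test` target for every other .mojo file in
# this package (found recursively), excluding the multimem test above.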
# Every .mojo file under this package except test_multimem_allreduce.mojo,
# which is declared with its own hardware-restricted target.
_REGULAR_GPU_TEST_SRCS = glob(
    ["**/*.mojo"],
    exclude = ["test_multimem_allreduce.mojo"],
)

# One `mojo_test` per source file, named "<file>.test". A list comprehension is
# used because top-level `for` statements are not allowed in BUILD files.
[
    mojo_test(
        name = "%s.test" % test_src,
        size = "large",
        srcs = [test_src],
        # Device-context buffer cache settings, passed to the test as
        # environment variables.
        env = {
            "MODULAR_DEVICE_CONTEXT_BUFFER_CACHE_ONLY": "false",
            "MODULAR_DEVICE_CONTEXT_BUFFER_CACHE_SIZE_PERCENT": "80",
        },
        # Execution property requesting GPU memory for the test (units are not
        # stated here — confirm against the executor configuration).
        exec_properties = {
            "test.resources:gpu-memory": "4",
        },
        tags = ["gpu"],
        # Requires both the //:has_gpu and //:has_multi_gpu constraints;
        # otherwise the target is skipped as incompatible.
        target_compatible_with = [
            "//:has_gpu",
            "//:has_multi_gpu",
        ],
        deps = [
            "//max:internal_utils",
            "//max:kv_cache",
            "//max:linalg",
            "//max:nn",
            "//max:quantization",
            "@mojo//:stdlib",
        ],
    )
    for test_src in _REGULAR_GPU_TEST_SRCS
]
